# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import ClassVar, List, overload

import numpy
import numpy.typing

class AtomicType:
    """Stub for the enumeration of atomic value types (BOOLEAN, NATURAL, REAL, ...).

    Used as ``sample_type`` in Graph.add_distribution and as the second
    argument of the ValueType constructor.
    """
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    BOOLEAN: ClassVar[AtomicType] = ...
    NATURAL: ClassVar[AtomicType] = ...
    NEG_REAL: ClassVar[AtomicType] = ...
    POS_REAL: ClassVar[AtomicType] = ...
    PROBABILITY: ClassVar[AtomicType] = ...
    REAL: ClassVar[AtomicType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class DistributionType:
    """Stub for the enumeration of distribution kinds accepted by Graph.add_distribution."""
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    BERNOULLI: ClassVar[DistributionType] = ...
    BERNOULLI_LOGIT: ClassVar[DistributionType] = ...
    BERNOULLI_NOISY_OR: ClassVar[DistributionType] = ...
    BETA: ClassVar[DistributionType] = ...
    BIMIXTURE: ClassVar[DistributionType] = ...
    BINOMIAL: ClassVar[DistributionType] = ...
    CATEGORICAL: ClassVar[DistributionType] = ...
    DIRICHLET: ClassVar[DistributionType] = ...
    FLAT: ClassVar[DistributionType] = ...
    GAMMA: ClassVar[DistributionType] = ...
    HALF_CAUCHY: ClassVar[DistributionType] = ...
    HALF_NORMAL: ClassVar[DistributionType] = ...
    NORMAL: ClassVar[DistributionType] = ...
    STUDENT_T: ClassVar[DistributionType] = ...
    TABULAR: ClassVar[DistributionType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class FactorType:
    """Stub for the enumeration of factor kinds accepted by Graph.add_factor.

    EXP_PRODUCT is the only declared member.
    """
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    EXP_PRODUCT: ClassVar[FactorType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class Graph:
    """Stub for the graph object: node construction, observation and inference.

    Matrix arguments are expected to be 2-D (m x n) numpy arrays. The generated
    annotations ``numpy.ndarray[numpy.float64[m, n]]`` (and the ``bool`` /
    ``uint64`` variants) referenced the undefined names ``m`` and ``n`` and
    subscripted the builtin ``bool``, so they were not valid type expressions.
    They are declared with ``numpy.typing.NDArray[<dtype>]`` instead, which
    keeps the dtype-specific ``observe`` overloads distinguishable.
    """
    def __init__(self) -> None: ...
    # NOTE(review): the add_* methods return an int, presumably the id of the
    # newly created node (ids are passed back as ``parents: List[int]``) -- confirm.
    @overload
    def add_constant(self, value: bool) -> int: ...
    @overload
    def add_constant(self, value: float) -> int: ...
    @overload
    def add_constant(self, value: int) -> int: ...
    @overload
    def add_constant(self, value: NodeValue) -> int: ...
    def add_constant_bool(self, value: bool) -> int: ...
    def add_constant_bool_matrix(
        self, value: numpy.typing.NDArray[numpy.bool_]
    ) -> int: ...
    def add_constant_col_simplex_matrix(
        self, value: numpy.typing.NDArray[numpy.float64]
    ) -> int: ...
    def add_constant_natural(self, value: int) -> int: ...
    def add_constant_natural_matrix(
        self, value: numpy.typing.NDArray[numpy.uint64]
    ) -> int: ...
    def add_constant_neg_matrix(
        self, value: numpy.typing.NDArray[numpy.float64]
    ) -> int: ...
    def add_constant_neg_real(self, value: float) -> int: ...
    def add_constant_pos_matrix(
        self, value: numpy.typing.NDArray[numpy.float64]
    ) -> int: ...
    def add_constant_pos_real(self, value: float) -> int: ...
    def add_constant_probability(self, value: float) -> int: ...
    def add_constant_probability_matrix(
        self, value: numpy.typing.NDArray[numpy.float64]
    ) -> int: ...
    def add_constant_real(self, value: float) -> int: ...
    def add_constant_real_matrix(
        self, value: numpy.typing.NDArray[numpy.float64]
    ) -> int: ...
    # sample_type may be given either as an AtomicType or as a full ValueType.
    @overload
    def add_distribution(
        self, dist_type: DistributionType, sample_type: AtomicType, parents: List[int]
    ) -> int: ...
    @overload
    def add_distribution(
        self, dist_type: DistributionType, sample_type: ValueType, parents: List[int]
    ) -> int: ...
    def add_factor(self, fac_type: FactorType, parents: List[int]) -> int: ...
    def add_operator(self, op: OperatorType, parents: List[int]) -> int: ...
    def collect_performance_data(self, b: bool) -> None: ...
    def customize_transformation(
        self, transform_type: TransformType, node_ids: List[int]
    ) -> None: ...
    def get_elbo(self) -> List[float]: ...
    def get_log_prob(self) -> List[List[float]]: ...
    # NOTE(review): in both infer/infer_mean overload pairs every parameter after
    # num_samples has a default, so a call such as ``infer(n)`` matches both
    # overloads even though they declare different return types. Type checkers
    # pick the first overload and may report an overlap. Making ``n_chains``
    # required in the second overload would disambiguate, but only once the
    # binding's actual defaults are confirmed.
    @overload
    def infer(
        self, num_samples: int, algorithm: InferenceType = ..., seed: int = ...
    ) -> List[List[NodeValue]]: ...
    @overload
    def infer(
        self,
        num_samples: int,
        algorithm: InferenceType = ...,
        seed: int = ...,
        n_chains: int = ...,
        infer_config: InferConfig = ...,
    ) -> List[List[List[NodeValue]]]: ...
    @overload
    def infer_mean(
        self, num_samples: int, algorithm: InferenceType = ..., seed: int = ...
    ) -> List[float]: ...
    @overload
    def infer_mean(
        self,
        num_samples: int,
        algorithm: InferenceType = ...,
        seed: int = ...,
        n_chains: int = ...,
        infer_config: InferConfig = ...,
    ) -> List[List[float]]: ...
    # val may be a bool/float/int scalar, a 2-D numpy array of float64, bool or
    # uint64, or a NodeValue.
    @overload
    def observe(self, node_id: int, val: bool) -> None: ...
    @overload
    def observe(self, node_id: int, val: float) -> None: ...
    @overload
    def observe(self, node_id: int, val: int) -> None: ...
    @overload
    def observe(
        self, node_id: int, val: numpy.typing.NDArray[numpy.float64]
    ) -> None: ...
    @overload
    def observe(
        self, node_id: int, val: numpy.typing.NDArray[numpy.bool_]
    ) -> None: ...
    @overload
    def observe(
        self, node_id: int, val: numpy.typing.NDArray[numpy.uint64]
    ) -> None: ...
    @overload
    def observe(self, node_id: int, val: NodeValue) -> None: ...
    def performance_report(self) -> str: ...
    def query(self, node_id: int) -> int: ...
    def remove_observations(self) -> None: ...
    def to_dot(self) -> str: ...
    def to_string(self) -> str: ...
    def variational(
        self,
        num_iters: int,
        steps_per_iter: int,
        seed: int = ...,
        elbo_samples: int = ...,
    ) -> List[List[float]]: ...

class HMC:
    """Stub for the HMC sampler class (presumably Hamiltonian Monte Carlo --
    inferred from the name), constructed over a Graph."""
    # NOTE(review): arg1/arg2 are unnamed floats; presumably path length and
    # step size (cf. InferConfig.path_length / step_size) -- confirm.
    def __init__(self, arg0: Graph, arg1: float, arg2: float) -> None: ...
    # num_samples and seed are required; the remaining arguments have defaults
    # whose values are not visible in this stub.
    # NOTE(review): the result is presumably one List[NodeValue] per sample -- confirm.
    def infer(
        self,
        num_samples: int,
        seed: int,
        num_warmup_samples: int = ...,
        save_warmup: bool = ...,
        init_type: InitType = ...,
    ) -> List[List[NodeValue]]: ...

class InferConfig:
    """Stub for the configuration passed as ``infer_config`` to Graph.infer
    and Graph.infer_mean.

    All five settings are exposed as read/write attributes.
    """
    keep_log_prob: bool
    keep_warmup: bool
    num_warmup: int
    path_length: float
    step_size: float
    # No-argument constructor (default settings; values not visible here).
    @overload
    def __init__(self) -> None: ...
    # NOTE(review): which attribute each positional argument sets is not
    # visible here. By type, arg0/arg4 are the two bool flags, arg1/arg2 the
    # two float settings and arg3 presumably num_warmup -- confirm against the
    # binding before relying on positional order.
    @overload
    def __init__(
        self, arg0: bool, arg1: float, arg2: float, arg3: int, arg4: bool
    ) -> None: ...

class InferenceType:
    """Stub for the enumeration of inference algorithms selectable via the
    ``algorithm`` argument of Graph.infer / Graph.infer_mean."""
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    GIBBS: ClassVar[InferenceType] = ...
    NMC: ClassVar[InferenceType] = ...
    REJECTION: ClassVar[InferenceType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class InitType:
    """Stub for the enumeration of initialization strategies accepted as
    ``init_type`` by HMC.infer and NUTS.infer."""
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    PRIOR: ClassVar[InitType] = ...
    RANDOM: ClassVar[InitType] = ...
    ZERO: ClassVar[InitType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class NUTS:
    """Stub for the NUTS sampler class (presumably the No-U-Turn Sampler --
    inferred from the name), constructed over a Graph."""
    def __init__(self, arg0: Graph) -> None: ...
    # Same signature as HMC.infer: num_samples and seed are required; the
    # remaining arguments have defaults whose values are not visible here.
    # NOTE(review): the result is presumably one List[NodeValue] per sample -- confirm.
    def infer(
        self,
        num_samples: int,
        seed: int,
        num_warmup_samples: int = ...,
        save_warmup: bool = ...,
        init_type: InitType = ...,
    ) -> List[List[NodeValue]]: ...

class Node:
    """Stub for an opaque node class.

    No members are declared and the constructor accepts arbitrary arguments.
    The Graph methods in this stub identify nodes by ``int`` ids instead.
    """
    def __init__(self, *args, **kwargs) -> None: ...

class NodeType:
    """Stub for the enumeration of graph node kinds
    (CONSTANT, DISTRIBUTION, FACTOR, OPERATOR)."""
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    CONSTANT: ClassVar[NodeType] = ...
    DISTRIBUTION: ClassVar[NodeType] = ...
    FACTOR: ClassVar[NodeType] = ...
    OPERATOR: ClassVar[NodeType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class NodeValue:
    """Stub for the value type used by Graph.add_constant / Graph.observe and
    returned by the inference methods.

    The constructor accepts a ``bool``, ``float`` or ``int`` scalar, or a 2-D
    (m x n) numpy array of ``bool`` or ``float64``. The array overloads were
    generated as ``numpy.ndarray[bool[m, n]]`` /
    ``numpy.ndarray[numpy.float64[m, n]]``, which reference the undefined names
    ``m`` and ``n``. They are declared with ``numpy.typing.NDArray`` so that the
    annotations are valid type expressions.
    """
    @overload
    def __init__(self, arg0: bool) -> None: ...
    @overload
    def __init__(self, arg0: float) -> None: ...
    @overload
    def __init__(self, arg0: int) -> None: ...
    @overload
    def __init__(self, arg0: numpy.typing.NDArray[numpy.bool_]) -> None: ...
    @overload
    def __init__(self, arg0: numpy.typing.NDArray[numpy.float64]) -> None: ...

class OperatorType:
    """Stub for the enumeration of operators accepted by Graph.add_operator."""
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    ADD: ClassVar[OperatorType] = ...
    BROADCAST: ClassVar[OperatorType] = ...
    BROADCAST_ADD: ClassVar[OperatorType] = ...
    CHOICE: ClassVar[OperatorType] = ...
    COLUMN_INDEX: ClassVar[OperatorType] = ...
    COMPLEMENT: ClassVar[OperatorType] = ...
    EXP: ClassVar[OperatorType] = ...
    EXPM1: ClassVar[OperatorType] = ...
    FILL_MATRIX: ClassVar[OperatorType] = ...
    IF_THEN_ELSE: ClassVar[OperatorType] = ...
    INDEX: ClassVar[OperatorType] = ...
    LOG: ClassVar[OperatorType] = ...
    LOG1MEXP: ClassVar[OperatorType] = ...
    LOG1PEXP: ClassVar[OperatorType] = ...
    LOGISTIC: ClassVar[OperatorType] = ...
    LOGSUMEXP: ClassVar[OperatorType] = ...
    LOGSUMEXP_VECTOR: ClassVar[OperatorType] = ...
    MATRIX_MULTIPLY: ClassVar[OperatorType] = ...
    MATRIX_SCALE: ClassVar[OperatorType] = ...
    MULTIPLY: ClassVar[OperatorType] = ...
    NEGATE: ClassVar[OperatorType] = ...
    PHI: ClassVar[OperatorType] = ...
    POW: ClassVar[OperatorType] = ...
    SAMPLE: ClassVar[OperatorType] = ...
    TO_MATRIX: ClassVar[OperatorType] = ...
    TO_NEG_REAL: ClassVar[OperatorType] = ...
    TO_POS_REAL: ClassVar[OperatorType] = ...
    TO_POS_REAL_MATRIX: ClassVar[OperatorType] = ...
    TO_PROBABILITY: ClassVar[OperatorType] = ...
    TO_REAL: ClassVar[OperatorType] = ...
    TO_REAL_MATRIX: ClassVar[OperatorType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class TransformType:
    """Stub for the enumeration of transform kinds accepted by
    Graph.customize_transformation."""
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    LOG: ClassVar[TransformType] = ...
    NONE: ClassVar[TransformType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...

class ValueType:
    """Stub for a value-type descriptor built from a VariableType and an
    AtomicType; accepted as ``sample_type`` by Graph.add_distribution."""
    # NOTE(review): arg2/arg3 are unnamed ints, presumably the row and column
    # counts of a matrix type -- confirm against the binding.
    def __init__(
        self, arg0: VariableType, arg1: AtomicType, arg2: int, arg3: int
    ) -> None: ...
    # Returns a string rendering of this type.
    def to_string(self) -> str: ...

class VariableType:
    """Stub for the enumeration of variable kinds used as the first argument
    of the ValueType constructor."""
    __doc__: ClassVar[str] = ...  # read-only
    __members__: ClassVar[dict] = ...  # read-only
    # Enumeration members.
    BROADCAST_MATRIX: ClassVar[VariableType] = ...
    COL_SIMPLEX_MATRIX: ClassVar[VariableType] = ...
    SCALAR: ClassVar[VariableType] = ...
    # Private (name-mangled) dict; not part of the public interface.
    __entries: ClassVar[dict] = ...
    # Construct from an integer value.
    def __init__(self, value: int) -> None: ...
    def __eq__(self, other: object) -> bool: ...
    # __getstate__/__setstate__ implement the pickle protocol; the state is an int.
    def __getstate__(self) -> int: ...
    def __hash__(self) -> int: ...
    # __index__/__int__ let a member be used where an int is expected.
    def __index__(self) -> int: ...
    def __int__(self) -> int: ...
    def __ne__(self, other: object) -> bool: ...
    def __setstate__(self, state: int) -> None: ...
    # Member name and its integer value.
    @property
    def name(self) -> str: ...
    @property
    def value(self) -> int: ...
